import os
import warnings

import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
# Silence UserWarning and FutureWarning messages for the rest of the script.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Opt in to pandas' future behaviour of not silently downcasting dtypes.
pd.set_option('future.no_silent_downcasting', True)
# Turn off pandas' chained-assignment check (no warning, no error).
pd.options.mode.chained_assignment = None 


# Load every split of the 'mlburnham/Pol_NLI' dataset, keep one DataFrame per
# split, and pool them into a single frame with a fresh 0..n-1 index.
ds = load_dataset('mlburnham/Pol_NLI')
train, val, test = (
    ds[split_name].to_pandas() for split_name in ('train', 'validation', 'test')
)
alldata = pd.concat([train, val, test], ignore_index=True)

def describe_dataset(data_dict, split):
    """Summarise one split of a dataset dictionary.

    Args:
        data_dict: Mapping from split name to an object exposing
            ``to_pandas()`` (for example a ``DatasetDict``). The resulting
            frame must contain the columns ``dataset``, ``premise``,
            ``hypothesis``, ``augmented_hypothesis`` and ``entailment``.
        split: Key of the split to describe.

    Returns:
        dict with the keys ``unique_datasets``, ``unique_premises``,
        ``total_premises`` (row count), ``unique_hypotheses``,
        ``unique_augmented_hypotheses``, ``avg_premise_word_count``,
        ``med_premise_word_count`` and ``entailment_counts`` (a mapping from
        each ``entailment`` label to its number of rows).
    """
    df = data_dict[split].to_pandas()  # Convert the split to a pandas DataFrame

    # Whitespace-separated word count per premise, computed once and reused
    # for both statistics. The vectorised form yields NaN for a missing
    # premise, which mean()/median() skip, instead of raising.
    premise_word_counts = df['premise'].str.split().str.len()

    return {
        'unique_datasets': df['dataset'].nunique(),
        'unique_premises': df['premise'].nunique(),
        'total_premises': len(df['premise']),
        'unique_hypotheses': df['hypothesis'].nunique(),
        'unique_augmented_hypotheses': df['augmented_hypothesis'].nunique(),
        'avg_premise_word_count': premise_word_counts.mean(),
        'med_premise_word_count': premise_word_counts.median(),
        'entailment_counts': df['entailment'].value_counts().to_dict(),
    }

###########
## Table 4
###########
# One row per source dataset: number of distinct hypotheses, number of rows,
# the sum of the `entailment` labels (reported below as "Not-Entail"), the
# task of the first row, and the median premise length in whitespace-separated
# words.
results = alldata.groupby("dataset").agg(
    unique_hypotheses=("hypothesis", "nunique"),
    total_rows=("hypothesis", "count"),
    sum_entailment=("entailment", "sum"),
    task=("task", "first"),
    median_premise_word_count=("premise", lambda x: x.str.split().apply(len).median()),
).reset_index()

# Calculate total documents per task and overall
total_docs_per_task = results.groupby("task")["total_rows"].transform("sum")
total_docs_overall = results["total_rows"].sum()

# Calculate percentages. Scale to percent first and round last so the stored
# values hold exactly one decimal place instead of float artefacts such as
# 7.000000000000001.
results["percent_of_task"] = (results["total_rows"] / total_docs_per_task * 100).round(1)
results["percent_of_total"] = (results["total_rows"] / total_docs_overall * 100).round(1)

# Sort by the task column. A stable sort keeps the datasets inside each task in
# the (sorted) order produced by the groupby, so the table is reproducible.
task_order = ["stance detection", "topic classification", "hatespeech and toxicity", "event extraction"]
results["task"] = pd.Categorical(results["task"], categories=task_order, ordered=True)
results = results.sort_values(by="task", kind="stable")

# Rename by name rather than by position so a change to the aggregation above
# cannot silently mislabel a column.
results = results.rename(columns={
    "dataset": "Dataset",
    "unique_hypotheses": "Hypotheses",
    "total_rows": "Documents",
    "sum_entailment": "Not-Entail",
    "task": "Task",
    "median_premise_word_count": "Median Word Count",
    "percent_of_task": "% of Task",
    "percent_of_total": "% of Total",
})
results['Entail'] = results['Documents'] - results['Not-Entail']
results = results[['Task', 'Dataset', 'Documents', 'Hypotheses', 'Entail', 'Not-Entail',  '% of Task', '% of Total', 'Median Word Count']]

# Save the DataFrame to a text file, creating the output directory if needed
output_file = './tables/table_4.txt'
os.makedirs(os.path.dirname(output_file), exist_ok=True)
df_string = results.to_string()

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(df_string)

#########################################
## Shared helpers for Tables 5 and 6
#########################################
def _build_summary_table(data_dict, splits):
    """Build a per-split summary table with a trailing 'Total' row.

    Args:
        data_dict: Mapping from split name to an object exposing
            ``to_pandas()``; it is passed unchanged to ``describe_dataset``.
        splits: Names of the splits to report, in row order.

    Returns:
        DataFrame indexed by the split names plus 'Total'. For the per-split
        rows, 'Entail' is the number of rows whose ``entailment`` label is 0
        and 'Not Entail' the number whose label is 1. For the 'Total' row,
        'Not Entail' is the sum of the labels over all splits and 'Entail' is
        the remaining row count; the other totals are computed on the pooled
        rows (so unique counts are not sums of the per-split values).
    """
    stats = {name: describe_dataset(data_dict, name) for name in splits}

    table = pd.DataFrame(
        {
            'Unique Datasets': [stats[name]['unique_datasets'] for name in splits],
            'Total Documents': [stats[name]['total_premises'] for name in splits],
            'Unique Hypotheses': [stats[name]['unique_hypotheses'] for name in splits],
            'Median Premise Length': [f"{stats[name]['med_premise_word_count']:.2f}" for name in splits],
            'Entail': [stats[name]['entailment_counts'].get(0, 0) for name in splits],
            'Not Entail': [stats[name]['entailment_counts'].get(1, 0) for name in splits],
        },
        index=splits,
    )

    # Pool all requested splits to compute the totals row.
    pooled = pd.concat([data_dict[name].to_pandas() for name in splits])
    total_documents = pooled.shape[0]
    total_not_entail = pooled['entailment'].sum()

    table.loc['Total'] = [
        pooled['dataset'].nunique(),
        total_documents,
        pooled['hypothesis'].nunique(),
        f"{pooled['premise'].apply(lambda x: len(x.split())).median():.2f}",
        total_documents - total_not_entail,
        total_not_entail,
    ]
    return table


def _write_table(table, path):
    """Write ``table.to_string()`` to ``path``, creating its directory if needed."""
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(table.to_string())


###########
## Table 5
###########
# Regroup the pooled data by task so that each task can be described the same
# way as a train/validation/test split.
task_labels = {
    'stance': 'stance detection',
    'event': 'event extraction',
    'hate': 'hatespeech and toxicity',
    'topic': 'topic classification',
}
task_ds = DatasetDict({
    key: Dataset.from_pandas(alldata[alldata['task'] == label])
    for key, label in task_labels.items()
})

# Row order of the table
splits = ['stance', 'topic', 'hate', 'event']
df = _build_summary_table(task_ds, splits)

# Save the DataFrame to a text file
_write_table(df, './tables/table_5.txt')

###########
## Table 6
###########
# Same summary, this time over the original train/validation/test splits
splits = ['train', 'validation', 'test']
df = _build_summary_table(ds, splits)

# Save the DataFrame to a text file
_write_table(df, './tables/table_6.txt')

###########
## Table 7
###########
# Import and subset: hate-speech rows of the training split, excluding the
# generic "This text is hate speech." hypothesis.
dt_train = load_dataset("mlburnham/Pol_NLI", split="train")
dt_train_df = dt_train.to_pandas()
dt_train_hate = dt_train_df[dt_train_df['task'] == "hatespeech and toxicity"]
# Take an explicit copy so the column assignment below writes to an independent
# frame rather than to a slice of dt_train_hate.
dt_train_group = dt_train_hate[dt_train_hate['hypothesis'] != "This text is hate speech."].copy()

# Clean and organize labels: keep the last two words of each hypothesis, drop
# the listed filler words and the period, and trim surrounding whitespace.
dt_train_group["target"] = (
    dt_train_group["hypothesis"]
    .apply(lambda x: " ".join(x.split()[-2:]))
    .str.replace(r'\b(their|against|of|attacking|defending|dehumanizing)\b', '', regex=True)
    .str.replace('.', '', regex=False)
    .str.strip()
)

# Count rows per target. Naming the index and the value column explicitly keeps
# the resulting column names independent of the pandas version.
dt_train_sum = dt_train_group['target'].value_counts()
dt_train_sum_df = dt_train_sum.rename_axis("subcategory").reset_index(name="count")
dt_train_sum_df['subcategory'] = dt_train_sum_df['subcategory'].str.title()

rename_map = {
    "Gender": "General Gender",
    "Race": "General Race",
    "Sexuality": "Sexuality Overall",
    "Religion": "General Religion",
    "Origin": "General Origins",
    "The Disabled": "General Disabled",
    "Age": "General Age"
}

dt_train_sum_df["subcategory"] = dt_train_sum_df["subcategory"].replace(rename_map)

category_mapping = {
    "General Gender": "Gender and Sexuality",
    "Men": "Gender and Sexuality",
    "Women": "Gender and Sexuality",
    "Straight People": "Gender and Sexuality",
    "Gay People": "Gender and Sexuality",
    "Bisexual People": "Gender and Sexuality",
    "Lesbians": "Gender and Sexuality",
    "Transgender People": "Gender and Sexuality",
    "Transgender Men": "Gender and Sexuality",
    "Transgender Women": "Gender and Sexuality",
    "Binary People": "Gender and Sexuality",
    "Sexuality Overall": "Gender and Sexuality",

    "General Religion": "Religion",
    "Muslims": "Religion",
    "Christians": "Religion",
    "Atheists": "Religion",
    "Hindus": "Religion",
    "Mormons": "Religion",
    "Buddhists": "Religion",

    "General Race": "Race and Ethnicity",
    "General Ethnicity": "Race and Ethnicity",
    "Blacks": "Race and Ethnicity",
    "Whites": "Race and Ethnicity",
    "Middle Easterners": "Race and Ethnicity",
    "Asians": "Race and Ethnicity",
    "Jews": "Race and Ethnicity",
    "Latinos": "Race and Ethnicity",
    "Native Americans": "Race and Ethnicity",
    "Pacific Islanders": "Race and Ethnicity",

    "General Origins": "Immigration",
    "Immigrants": "Immigration",
    "Undocumented People": "Immigration",
    "Migrant Workers": "Immigration",

    "General Disabled": "Disability",
    "Cognitively Disabled": "Disability",
    "Neurologically Disabled": "Disability",
    "Physically Disabled": "Disability",
    "Visually Impaired": "Disability",
    "Hearing Impaired": "Disability",

    "General Age": "Age",
    "Young Adults": "Age",
    "Children": "Age",
    "Teenagers": "Age",
    "Seniors": "Age",
    "Middle Aged": "Age",

    "Political Outgroups": "Political Outgroups"
}

# Assign category column (subcategories missing from the mapping become NaN)
dt_train_sum_df["category"] = dt_train_sum_df["subcategory"].map(category_mapping)

# Reorder columns
dt_train_sum_df = dt_train_sum_df[["category", "subcategory", "count"]]

# Define category order
category_order = [
    "Gender and Sexuality",
    "Religion",
    "Race and Ethnicity",
    "Immigration",
    "Disability",
    "Age",
    "Political Outgroups"
]

# Sort by category order, then count descending
dt_train_sum_df["category"] = pd.Categorical(dt_train_sum_df["category"], categories=category_order, ordered=True)
dt_train_sum_df = dt_train_sum_df.sort_values(["category", "count"], ascending=[True, False])

# Reset index
dt_train_sum_df = dt_train_sum_df.reset_index(drop=True)

# Save the DataFrame to a text file, creating the output directory if needed
output_file = './tables/table_7.txt'
os.makedirs(os.path.dirname(output_file), exist_ok=True)
df_string = dt_train_sum_df.to_string()

with open(output_file, 'w', encoding='utf-8') as f:
    f.write(df_string)